import argparse
import time
import numpy as np
import networkx as nx
import tensorflow as tf
from tensorflow.keras import layers
from dgl import DGLGraph
from dgl.data import register_data_args, load_data
from dgi import DGI, Classifier


def evaluate(model, features, labels, mask):
    """Return the accuracy of ``model`` on the entries selected by ``mask``.

    The model is called on ``features`` with ``training=False``; both its
    output and ``labels`` are then indexed by ``mask``.  The prediction for
    each remaining row is the arg-max along axis 1, and the result is the
    mean of the element-wise matches against the labels.

    Returns:
        The accuracy as a plain Python float.
    """
    masked_logits = model(features, training=False)[mask]
    masked_labels = labels[mask]
    predicted = tf.math.argmax(masked_logits, axis=1)
    # 1.0 where the prediction matches the label, 0.0 elsewhere.
    hits = tf.cast(predicted == masked_labels, dtype=tf.float32)
    return tf.reduce_mean(hits).numpy().item()


def _mean_or_nan(durations):
    """Return the mean of ``durations``, or NaN when the list is empty."""
    # np.mean([]) also evaluates to NaN, but it emits a RuntimeWarning on
    # every call; the first timed sample only arrives at epoch 3.
    return float(np.mean(durations)) if durations else float("nan")


def main(args):
    """Train DGI without labels, then fit a classifier on its embeddings.

    Steps:
      1. Load the dataset via ``load_data(args)`` and convert features,
         labels and masks to tensors on the selected device.
      2. Optionally rebuild the self-loops of the graph, then wrap it in a
         ``DGLGraph``.
      3. Train the ``DGI`` model with early stopping on its loss, saving the
         best weights to ``best_dgi.pkl`` in the working directory.
      4. Reload the best weights, compute embeddings with the encoder
         (gradients stopped), and train a ``Classifier`` on the training
         mask, printing validation accuracy each epoch.
      5. Print the accuracy on the test mask.

    Args:
        args: Parsed command-line namespace.  Reads ``gpu``, ``self_loop``,
            ``n_hidden``, ``n_layers``, ``dropout``, ``dgi_lr``,
            ``n_dgi_epochs``, ``weight_decay``, ``patience``,
            ``classifier_lr`` and ``n_classifier_epochs``; the namespace is
            also passed unchanged to ``load_data``.

    Side effects:
        Prints progress to stdout and writes ``best_dgi.pkl``.
    """
    # load and preprocess dataset
    data = load_data(args)
    device = "/cpu:0" if args.gpu < 0 else "/gpu:{}".format(args.gpu)
    with tf.device(device):
        features = tf.convert_to_tensor(data.features, dtype=tf.float32)
        labels = tf.convert_to_tensor(data.labels, dtype=tf.int64)
        train_mask = tf.convert_to_tensor(data.train_mask, dtype=tf.bool)
        val_mask = tf.convert_to_tensor(data.val_mask, dtype=tf.bool)
        test_mask = tf.convert_to_tensor(data.test_mask, dtype=tf.bool)
        in_feats = features.shape[1]
        n_classes = data.num_labels

        # graph preprocess
        g = data.graph
        if args.self_loop:
            # Drop any existing self-loops, then add exactly one per node.
            # Both iterables are materialised first: they are lazy views of
            # the very graph being mutated, and changing the graph while
            # iterating over it can raise a RuntimeError.
            g.remove_edges_from(list(nx.selfloop_edges(g)))
            g.add_edges_from([(node, node) for node in g.nodes()])
        g = DGLGraph(g)
        # Edge count after preprocessing; used for the throughput figure.
        n_edges = g.number_of_edges()

        # create DGI model
        dgi = DGI(g,
                  in_feats,
                  args.n_hidden,
                  args.n_layers,
                  tf.keras.layers.PReLU(alpha_initializer=tf.constant_initializer(0.25)),
                  args.dropout)

        dgi_optimizer = tf.keras.optimizers.Adam(
            learning_rate=args.dgi_lr)

        # train deep graph infomax
        cnt_wait = 0
        best = 1e9
        best_t = 0
        checkpoint_saved = False
        dur = []
        for epoch in range(args.n_dgi_epochs):
            # The first three epochs are excluded from the timing statistics.
            if epoch >= 3:
                t0 = time.time()

            with tf.GradientTape() as tape:
                loss = dgi(features)
                # Manual weight decay.
                # TensorFlow's Adam(W) weight decay is implemented differently
                # from PyTorch's, which led to worse results here.  Adding
                # the L2 penalty to the loss by hand avoids the discrepancy.
                for weight in dgi.trainable_weights:
                    loss = loss + args.weight_decay * tf.nn.l2_loss(weight)
            # Differentiate and step outside the tape context so that the
            # backward pass and optimizer update are not recorded themselves.
            grads = tape.gradient(loss, dgi.trainable_weights)
            dgi_optimizer.apply_gradients(zip(grads, dgi.trainable_weights))

            if loss < best:
                best = loss
                best_t = epoch
                cnt_wait = 0
                dgi.save_weights('best_dgi.pkl')
                checkpoint_saved = True
            else:
                cnt_wait += 1

            if cnt_wait == args.patience:
                print('Early stopping!')
                break

            if epoch >= 3:
                dur.append(time.time() - t0)

            mean_dur = _mean_or_nan(dur)
            print("Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | "
                  "ETputs(KTEPS) {:.2f}".format(epoch, mean_dur, loss.numpy().item(),
                                                n_edges / mean_dur / 1000))

        # create classifier model
        classifier = Classifier(args.n_hidden, n_classes)

        classifier_optimizer = tf.keras.optimizers.Adam(learning_rate=args.classifier_lr)

        # train classifier
        if checkpoint_saved:
            print('Loading {}th epoch'.format(best_t))
            dgi.load_weights('best_dgi.pkl')
        else:
            # No epoch ran, or the loss never improved: nothing was written
            # during this run, so do not load a missing or stale file.
            print('No DGI checkpoint saved in this run; using current weights')
        embeds = dgi.encoder(features, corrupt=False)
        # The encoder is frozen from here on; only the classifier is trained.
        embeds = tf.stop_gradient(embeds)
        dur = []
        loss_fcn = tf.keras.losses.SparseCategoricalCrossentropy(
            from_logits=True)
        for epoch in range(args.n_classifier_epochs):
            if epoch >= 3:
                t0 = time.time()
            with tf.GradientTape() as tape:
                preds = classifier(embeds)
                loss = loss_fcn(labels[train_mask], preds[train_mask])
                # No weight decay is added here on purpose: the original
                # implementation applies none to the classifier.
                # link: https://github.com/PetarV-/DGI/blob/master/execute.py#L121
            grads = tape.gradient(loss, classifier.trainable_weights)
            classifier_optimizer.apply_gradients(zip(grads, classifier.trainable_weights))
            if epoch >= 3:
                dur.append(time.time() - t0)

            acc = evaluate(classifier, embeds, labels, val_mask)
            mean_dur = _mean_or_nan(dur)
            print("Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | Accuracy {:.4f} | "
                  "ETputs(KTEPS) {:.2f}".format(epoch, mean_dur, loss.numpy().item(),
                                                acc, n_edges / mean_dur / 1000))

        print()
        acc = evaluate(classifier, embeds, labels, test_mask)
        print("Test Accuracy {:.4f}".format(acc))


if __name__ == '__main__':
    # Script entry point: assemble the CLI, parse it, echo it, and run.
    parser = argparse.ArgumentParser(description='DGI')
    # Let DGL register its dataset-related arguments on the parser.
    register_data_args(parser)

    # (flag, type, default, help) for each valued option, in display order.
    valued_options = (
        ("--dropout", float, 0., "dropout probability"),
        ("--gpu", int, -1, "gpu"),
        ("--dgi-lr", float, 1e-3, "dgi learning rate"),
        ("--classifier-lr", float, 1e-2, "classifier learning rate"),
        ("--n-dgi-epochs", int, 300, "number of training epochs"),
        ("--n-classifier-epochs", int, 300, "number of training epochs"),
        ("--n-hidden", int, 512, "number of hidden gcn units"),
        ("--n-layers", int, 1, "number of hidden gcn layers"),
        ("--weight-decay", float, 0., "Weight for L2 loss"),
        ("--patience", int, 20, "early stop patience condition"),
    )
    for flag, value_type, default_value, help_text in valued_options:
        parser.add_argument(flag, type=value_type, default=default_value,
                            help=help_text)

    # Boolean switch; absent means no self-loop preprocessing.
    parser.add_argument("--self-loop", action='store_true',
                        help="graph self-loop (default=False)")
    parser.set_defaults(self_loop=False)

    args = parser.parse_args()
    print(args)

    main(args)
